import hydra
import random
import os
from tqdm import tqdm
import numpy as np
from dataset.sim_1d_no_x import Sim1d_noX
from dataset.high_dimension_2 import High_dim
from utils.utils import get_logger
from utils.writer import Writer
from omegaconf import OmegaConf
from hydra.core.hydra_config import HydraConfig
from utils.utils import set_random_seed
from omegaconf import open_dict
from utils.utils import save_array_to_npy
from utils.data_class import extract_data_from_npz,extract_effect_from_npz
from model.rkhs.Trainer import RKHS_Trainer



# Dataset types ``Experiment`` knows how to run; any other value is a config error.
_SUPPORTED_DATASET_TYPES = ("sim1d_no_x", "high_2", "ab")

# Fixed seeds for the simulated-data trials (one trial per seed).
_SIM_SEEDS = np.array([1009, 1102, 1656, 1816, 2029,
                       2297, 2533, 2807, 4259, 4379,
                       4388, 4987, 5518, 5654, 5949,
                       7422, 7987, 8455, 9783, 9886])


def _fit_and_evaluate(cfg, train_dataset, test_dataset, treatment, **q_fit_kwargs):
    """Fit the h and q models on ``train_dataset`` and evaluate all three estimators.

    Args:
        cfg: Experiment config, forwarded to ``RKHS_Trainer``.
        train_dataset: Data the models are fit on.
        test_dataset: Data passed to the ``_htest`` / ``_qtest`` / ``_drtest`` evaluators.
        treatment: Treatment values returned by the dataset's ``generate_test_effect``
            (passed through unchanged to the evaluators).
        **q_fit_kwargs: Extra keyword arguments forwarded to ``fit_q_cv``
            (e.g. ``type='normflow_new'``); empty means ``fit_q_cv()``'s defaults.

    Returns:
        Tuple ``(ATE_h, ATE_q, ATE_dr)`` holding whatever the three evaluators return.
    """
    rkhs_train = RKHS_Trainer(cfg, train_dataset)

    # Train h, then evaluate the h-based estimate.
    rkhs_train.fit_h_cv()
    ATE_h = rkhs_train._htest(treatment, test_dataset)

    # Train q, then evaluate the q-based estimate.
    rkhs_train.fit_q_cv(**q_fit_kwargs)
    ATE_q = rkhs_train._qtest(treatment, test_dataset)

    # DR estimate is evaluated only after both h and q have been fit.
    ATE_dr = rkhs_train._drtest(treatment, test_dataset)
    return ATE_h, ATE_q, ATE_dr


def Experiment(cfg):
    """Run the RKHS h / q / DR estimators over several seeds and save the results.

    The data source is selected by ``cfg.dataset.type``:

    * ``'sim1d_no_x'`` / ``'high_2'``: simulated data, one trial per seed in
      ``_SIM_SEEDS``. Per-seed estimates are collected and saved at the end as
      the ``ATE_h`` / ``ATE_q`` / ``ATE_dr`` arrays.
    * ``'ab'``: pre-generated data loaded from ``.npz`` files under
      ``dataset/abortion`` for seeds 0-9. Each seed's estimates are saved to
      their own files as soon as they are computed.

    Args:
        cfg: Hydra/OmegaConf config. This function reads ``dataset.type``,
            ``dataset.num`` and ``log.chkpt_dir``. The whole config is also
            handed to the logger, writer, trainer and ``save_array_to_npy``.

    Raises:
        ValueError: If ``cfg.dataset.type`` is not one of ``_SUPPORTED_DATASET_TYPES``.
    """
    dataset_type = cfg.dataset.type
    # Fail fast, before any logs/directories are created. Previously an unknown
    # type fell through every branch and crashed later with a NameError on ``tar``.
    if dataset_type not in _SUPPORTED_DATASET_TYPES:
        raise ValueError(
            f"Unsupported cfg.dataset.type {dataset_type!r}; "
            f"expected one of {_SUPPORTED_DATASET_TYPES}"
        )

    logger = get_logger(cfg, os.path.basename(__file__))
    # NOTE(review): ``writer`` is not used in this function. It is kept in case
    # Writer's constructor has side effects (e.g. creating the tensorboard dir); confirm.
    writer = Writer(cfg, "tensorboard")
    os.makedirs(cfg.log.chkpt_dir, exist_ok=True)
    cfg_str = OmegaConf.to_yaml(cfg)
    logger.info("Config:\n" + cfg_str)

    ATE_h_list, ATE_q_list, ATE_dr_list = [], [], []

    if dataset_type == 'sim1d_no_x':
        treatment, tar = Sim1d_noX.generate_test_effect(-1, 2, 100)
        for seed in tqdm(_SIM_SEEDS):
            train_dataset = Sim1d_noX(seed, cfg.dataset.num).generatate_sim(W_miss=False, Z_miss=False)
            test_dataset = Sim1d_noX.generate_test(1000, seed + 1)

            ATE_h, ATE_q, ATE_dr = _fit_and_evaluate(cfg, train_dataset, test_dataset, treatment)
            ATE_h_list.append(ATE_h)
            ATE_q_list.append(ATE_q)
            ATE_dr_list.append(ATE_dr)

    elif dataset_type == 'high_2':
        # NOTE(review): 'quardratic' is passed verbatim to High_dim; its spelling
        # presumably has to match what High_dim expects, so it is left unchanged.
        treatment, tar = High_dim.generate_test_effect(0, 1, 100, 'quardratic', 10, 10, 100)
        for seed in tqdm(_SIM_SEEDS):
            high_data = High_dim(seed, cfg.dataset.num, dim_z=10, dim_w=10, dim_x=100)
            train_dataset = high_data.generatate_high(False)
            test_dataset = high_data.generate_test(1000, seed + 1, False)

            # Unlike the other datasets, q is fit with the 'normflow_new' type here.
            ATE_h, ATE_q, ATE_dr = _fit_and_evaluate(
                cfg, train_dataset, test_dataset, treatment, type='normflow_new'
            )
            ATE_h_list.append(ATE_h)
            ATE_q_list.append(ATE_q)
            ATE_dr_list.append(ATE_dr)

    elif dataset_type == 'ab':
        folder_path_train = 'dataset/abortion/train'
        folder_path_effect = 'dataset/abortion/test'
        for seed in tqdm(range(10)):
            train_path = f'{folder_path_train}/main_ab_seed{seed}.npz'
            train_dataset, test_dataset = extract_data_from_npz(train_path)

            effect_path = f'{folder_path_effect}/do_A_ab_seed{seed}.npz'
            # NOTE(review): ``tar`` is overwritten on every iteration, so the
            # "Groud_Truth" file saved after the loop holds only the last seed's values.
            treatment, tar = extract_effect_from_npz(effect_path)

            rkhs_train = RKHS_Trainer(cfg, train_dataset)

            # Each estimate is saved right after it is computed, so a failure in a
            # later fit still leaves the earlier per-seed results on disk.
            rkhs_train.fit_h_cv()
            ATE_h = rkhs_train._htest(treatment, test_dataset)
            save_array_to_npy(np.array(ATE_h), f"h_seed{seed}", cfg)

            rkhs_train.fit_q_cv()
            ATE_q = rkhs_train._qtest(treatment, test_dataset)
            save_array_to_npy(np.array(ATE_q), f"q_seed{seed}", cfg)

            ATE_dr = rkhs_train._drtest(treatment, test_dataset)
            save_array_to_npy(np.array(ATE_dr), f"dr_seed{seed}", cfg)

    # NOTE: the "Groud_Truth" file name is kept as spelled; downstream readers may depend on it.
    save_array_to_npy(tar, "Groud_Truth", cfg)
    # For 'ab' these lists are empty; that branch already saved its estimates per seed.
    save_array_to_npy(np.array(ATE_h_list), "ATE_h", cfg)
    save_array_to_npy(np.array(ATE_q_list), "ATE_q", cfg)
    save_array_to_npy(np.array(ATE_dr_list), "ATE_dr", cfg)
   

@hydra.main(version_base="1.2", config_path="config", config_name="default")
def main(hydra_cfg):
    """Hydra entry point: finish preparing the config, seed the RNGs, run the experiment.

    Args:
        hydra_cfg: Config composed by Hydra from ``config/default``.
    """
    # Copy Hydra's job-logging settings into the config. open_dict is needed so a
    # new key can be added even when the config is in struct mode.
    job_logging = HydraConfig.get().job_logging
    with open_dict(hydra_cfg):
        hydra_cfg.job_logging_cfg = job_logging

    if hydra_cfg.random_seed is None:
        # No seed configured: draw one and record it back into the config.
        hydra_cfg.random_seed = random.randint(1, 10000)
    set_random_seed(hydra_cfg.random_seed)

    Experiment(hydra_cfg)
    

if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator composes the config
    # (config/default plus any command-line overrides) and passes it to main.
    main()
